# -*- coding: utf-8 -*-
"""Using Adversarially Learned Anomaly Detection
"""
# Author: Michiel Bongaerts (but not author of the ALAD method)
# Pytorch version Author: Jiaqi Li <jli77629@usc.edu>


import numpy as np
import pandas as pd

try:
    import torch
except ImportError:
    print('please install torch first')

import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

from .base import BaseDetector
from ..utils.utility import check_parameter


class ALAD(BaseDetector):
    """Adversarially Learned Anomaly Detection (ALAD).
    Paper: https://arxiv.org/pdf/1812.02288.pdf

    See :cite:`zenati2018adversarially` for details.

    An encoder E (x -> z) and a decoder/generator G (z -> x) are trained
    adversarially against three discriminators: D_xz on (x, E(x)) versus
    (G(z), z), and the cycle-consistency discriminators D_xx on (x, x)
    versus (x, G(E(x))) and D_zz on (z, z) versus (z, E(G(z))). The outlier
    score of a sample compares the output of D_xx on (x, x) with its output
    on (x, G(E(x))).

    Parameters
    ----------
    activation_hidden_gen : str, optional (default='tanh')
        Activation function to use for hidden layers in encoder and decoder
        (i.e. generator). One of 'tanh', 'sigmoid' or 'relu'.

    activation_hidden_disc : str, optional (default='tanh')
        Activation function to use for hidden layers in the discriminators.
        One of 'tanh', 'sigmoid' or 'relu'.

    output_activation : str or None, optional (default=None)
        Activation function to use for the output layers of the encoder and
        decoder. If None, those output layers are linear.

    dropout_rate : float in [0., 1), optional (default=0.2)
        The dropout to be used after every hidden layer of all networks.

    latent_dim : int, optional (default=2)
        Dimensionality of the latent space z.

    dec_layers : list, optional (default=[5, 10, 25])
        List that indicates the number of nodes per hidden layer for the
        decoder network.
        Thus, [10, 10] indicates 2 hidden layers having each 10 nodes.

    enc_layers : list, optional (default=[25, 10, 5])
        List that indicates the number of nodes per hidden layer for the
        encoder network.
        Thus, [10, 10] indicates 2 hidden layers having each 10 nodes.

    disc_xx_layers : list, optional (default=[25, 10, 5])
        List that indicates the number of nodes per hidden layer for
        discriminator_xx.
        Thus, [10, 10] indicates 2 hidden layers having each 10 nodes.

    disc_zz_layers : list, optional (default=[25, 10, 5])
        List that indicates the number of nodes per hidden layer for
        discriminator_zz.
        Thus, [10, 10] indicates 2 hidden layers having each 10 nodes.

    disc_xz_layers : list, optional (default=[25, 10, 5])
        List that indicates the number of nodes per hidden layer for
        discriminator_xz.
        Thus, [10, 10] indicates 2 hidden layers having each 10 nodes.

    learning_rate_gen : float in (0., 1), optional (default=0.0001)
        Learning rate (Adam) for training the encoder and decoder.

    learning_rate_disc : float in (0., 1), optional (default=0.0001)
        Learning rate (Adam) for training the discriminators.

    add_recon_loss : bool, optional (default=False)
        Add an extra loss for encoder and decoder based on the mean squared
        reconstruction error of G(E(x)).

    lambda_recon_loss : float in (0., 1), optional (default=0.1)
        If ``add_recon_loss=True``, the reconstruction loss gets multiplied
        by ``lambda_recon_loss`` and added to the total loss for the
        generator (i.e. encoder and decoder).

    epochs : int, optional (default=200)
        Number of training iterations. Each iteration draws a single random
        mini-batch of ``batch_size`` samples and performs one discriminator
        update followed by one generator update.

    verbose : int, optional (default=0)
        Verbosity mode.
        - 0 = silent
        - 1 = print the iteration number every 50 training iterations

    preprocessing : bool, optional (default=False)
        If True, apply standardization on the data.

    add_disc_zz_loss : bool, optional (default=True)
        If True, also train discriminator_zz on (z, z) versus (z, E(G(z)))
        and add its adversarial loss to both discriminator and generator
        losses.

    spectral_normalization : bool, optional (default=False)
        If True, apply spectral normalization to the hidden linear layers of
        the discriminators.

    batch_size : int, optional (default=32)
        Number of samples per gradient update.

    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set, i.e.
        the proportion of outliers in the data set. When fitting this is used
        to define the threshold on the decision function.

    device : str, torch.device or None, optional (default=None)
        The device to use for computation. If None, 'cuda' is used when
        available and 'cpu' otherwise. Any value accepted by
        ``torch.nn.Module.to`` may be given (e.g. 'cpu', 'cuda', 'cuda:1');
        'gpu' is accepted as an alias for 'cuda'.

    Attributes
    ----------
    decision_scores_ : numpy array of shape (n_samples,)
        The outlier scores of the training data [0,1].
        The higher, the more abnormal. Outliers tend to have higher
        scores. This value is available once the detector is
        fitted.

    threshold_ : float
        The threshold is based on ``contamination``. It is the
        ``n_samples * contamination`` most abnormal samples in
        ``decision_scores_``. The threshold is calculated for generating
        binary outlier labels.

    labels_ : int, either 0 or 1
        The binary labels of the training data. 0 stands for inliers
        and 1 for outliers/anomalies. It is generated by applying
        ``threshold_`` on ``decision_scores_``.

    hist_loss_gen : list of float
        Generator loss recorded at every training iteration.

    hist_loss_disc : list of float
        Discriminator loss recorded at every training iteration.
    """

    def __init__(self, activation_hidden_gen='tanh',
                 activation_hidden_disc='tanh',
                 output_activation=None,
                 dropout_rate=0.2,
                 latent_dim=2,
                 dec_layers=[5, 10, 25],
                 enc_layers=[25, 10, 5],
                 disc_xx_layers=[25, 10, 5],
                 disc_zz_layers=[25, 10, 5],
                 disc_xz_layers=[25, 10, 5],
                 learning_rate_gen=0.0001, learning_rate_disc=0.0001,
                 add_recon_loss=False, lambda_recon_loss=0.1,
                 epochs=200,
                 verbose=0,
                 preprocessing=False,
                 add_disc_zz_loss=True, spectral_normalization=False,
                 batch_size=32, contamination=0.1, device=None):
        # NOTE: the list defaults above are shared between instances; this
        # class only ever reads them, never mutates them.
        super(ALAD, self).__init__(contamination=contamination)

        self.device = self._resolve_device(device)
        self.activation_hidden_disc = activation_hidden_disc
        self.activation_hidden_gen = activation_hidden_gen
        self.output_activation = output_activation
        self.dropout_rate = dropout_rate
        self.latent_dim = latent_dim
        self.dec_layers = dec_layers
        self.enc_layers = enc_layers

        self.disc_xx_layers = disc_xx_layers
        self.disc_zz_layers = disc_zz_layers
        self.disc_xz_layers = disc_xz_layers

        self.add_recon_loss = add_recon_loss
        self.lambda_recon_loss = lambda_recon_loss
        self.add_disc_zz_loss = add_disc_zz_loss

        self.contamination = contamination
        self.epochs = epochs
        self.learning_rate_gen = learning_rate_gen
        self.learning_rate_disc = learning_rate_disc
        self.preprocessing = preprocessing
        self.batch_size = batch_size
        self.verbose = verbose
        self.spectral_normalization = spectral_normalization

        if self.spectral_normalization:
            spectral_norm = getattr(nn.utils, 'spectral_norm', None)
            if spectral_norm is None:
                print('Spectral normalization not available. '
                      'Install torch>=1.0.0.')
                self.spectral_normalization = False
            else:
                # Kept for backward compatibility with code reading it.
                self.spectral_norm = spectral_norm

        check_parameter(dropout_rate, 0, 1, param_name='dropout_rate',
                        include_left=True)

    @staticmethod
    def _resolve_device(device):
        """Turn the ``device`` argument into a value ``Module.to`` accepts.

        A falsy value selects 'cuda' when available (else 'cpu'); the string
        'gpu' (documented as a valid choice) is translated to 'cuda', since
        PyTorch itself rejects it. Anything else is returned unchanged.
        """
        if not device:
            return torch.device(
                "cuda" if torch.cuda.is_available() else "cpu")
        if isinstance(device, str) and device.strip().lower() == 'gpu':
            return torch.device("cuda")
        return device

    @staticmethod
    def _get_activation(name):
        """Return a fresh activation module for ``name``.

        Raises
        ------
        ValueError
            If ``name`` is not 'tanh', 'sigmoid' or 'relu'.
        """
        factories = {'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid, 'relu': nn.ReLU}
        factory = factories.get(name) if isinstance(name, str) else None
        if factory is None:
            raise ValueError(
                "Unsupported activation function: {}".format(name))
        return factory()

    def _make_mlp(self, input_dim, hidden_dims, output_dim,
                  hidden_activation, output_activation=None,
                  use_spectral_norm=False):
        """Build ``[Linear -> Dropout -> activation] * k -> Linear`` on
        ``self.device``, optionally followed by ``output_activation``.

        When ``use_spectral_norm`` is True, spectral normalization wraps each
        hidden ``Linear`` (not the output layer).
        """
        layers = []
        for width in hidden_dims:
            linear = nn.Linear(input_dim, width)
            if use_spectral_norm:
                linear = nn.utils.spectral_norm(linear)
            layers += [linear,
                       nn.Dropout(self.dropout_rate),
                       self._get_activation(hidden_activation)]
            input_dim = width
        layers.append(nn.Linear(input_dim, output_dim))
        if output_activation:
            layers.append(self._get_activation(output_activation))
        return nn.Sequential(*layers).to(self.device)

    def _build_model(self):
        """Create encoder, decoder, the three discriminators and their Adam
        optimizers, and reset the loss histories.

        Requires ``self.n_features_`` to be set (done in ``fit``).
        """
        # Networks are created in a fixed order (decoder, encoder, D_xx,
        # D_zz, D_xz) so weight initialization is reproducible under a seed.
        self.dec = self._make_mlp(self.latent_dim, self.dec_layers,
                                  self.n_features_,
                                  self.activation_hidden_gen,
                                  self.output_activation)
        self.enc = self._make_mlp(self.n_features_, self.enc_layers,
                                  self.latent_dim,
                                  self.activation_hidden_gen,
                                  self.output_activation)

        # Each discriminator sees a concatenated pair and outputs a
        # probability through a final sigmoid.
        def discriminator(layers, input_dim):
            return self._make_mlp(input_dim, layers, 1,
                                  self.activation_hidden_disc, 'sigmoid',
                                  self.spectral_normalization)

        self.disc_xx = discriminator(self.disc_xx_layers,
                                     2 * self.n_features_)
        self.disc_zz = discriminator(self.disc_zz_layers,
                                     2 * self.latent_dim)
        self.disc_xz = discriminator(self.disc_xz_layers,
                                     self.n_features_ + self.latent_dim)

        self.opt_gen = optim.Adam(
            list(self.enc.parameters()) + list(self.dec.parameters()),
            lr=self.learning_rate_gen)
        self.opt_disc = optim.Adam(
            list(self.disc_xx.parameters())
            + list(self.disc_xz.parameters())
            + list(self.disc_zz.parameters()),
            lr=self.learning_rate_disc)

        self.hist_loss_disc = []
        self.hist_loss_gen = []

    def _networks(self):
        """Return all five networks of the model."""
        return (self.enc, self.dec, self.disc_xx, self.disc_zz, self.disc_xz)

    def _set_train_mode(self, mode):
        """Put every network in training (``True``) or eval (``False``)
        mode; this toggles dropout."""
        for network in self._networks():
            network.train(mode)

    def _to_tensor(self, array):
        """Convert an array-like to a float32 tensor on ``self.device``."""
        return torch.as_tensor(array, dtype=torch.float32,
                               device=self.device)

    def _generate(self, x_real, z_real):
        """Return ``(x_gen, z_gen, x_rec, z_rec)``, i.e. G(z), E(x), G(E(x))
        and E(G(z)). ``z_rec`` is None when ``add_disc_zz_loss`` is False.
        """
        x_gen = self.dec(z_real)
        z_gen = self.enc(x_real)
        x_rec = self.dec(z_gen)
        z_rec = self.enc(x_gen) if self.add_disc_zz_loss else None
        return x_gen, z_gen, x_rec, z_rec

    def _adversarial_loss(self, x_real, z_real, x_gen, z_gen, x_rec, z_rec,
                          generator):
        """Summed BCE loss over the discriminators.

        Discriminators are trained to output 1 on the "true" pairs
        (x, E(x)), (x, x), (z, z) and 0 on the "fake" pairs (G(z), z),
        (x, G(E(x))), (z, E(G(z))). With ``generator=True`` the targets are
        swapped, which yields the encoder/decoder loss.
        """
        bce = nn.BCELoss()

        def pair_loss(disc, true_pair, fake_pair):
            out_true = disc(torch.cat(true_pair, dim=1))
            out_fake = disc(torch.cat(fake_pair, dim=1))
            if generator:
                positive, negative = out_fake, out_true
            else:
                positive, negative = out_true, out_fake
            return (bce(positive, torch.ones_like(positive))
                    + bce(negative, torch.zeros_like(negative)))

        loss = pair_loss(self.disc_xz, (x_real, z_gen), (x_gen, z_real))
        # Cycle consistency in data space: x versus its reconstruction.
        loss = loss + pair_loss(self.disc_xx, (x_real, x_real),
                                (x_real, x_rec))
        if self.add_disc_zz_loss:
            # Cycle consistency in latent space: z versus E(G(z)).
            loss = loss + pair_loss(self.disc_zz, (z_real, z_real),
                                    (z_real, z_rec))
        return loss

    def train_step(self, data):
        """Run one discriminator update followed by one generator update.

        Parameters
        ----------
        data : tuple (x_real, z_real)
            ``x_real``: array-like of shape (batch, n_features) with training
            samples; ``z_real``: array-like of shape (batch, latent_dim) with
            latent noise. Both are converted to float32 tensors.

        The two losses are appended to ``hist_loss_disc`` and
        ``hist_loss_gen``.
        """
        x_real, z_real = data
        x_real = self._to_tensor(x_real)
        z_real = self._to_tensor(z_real)
        # Scoring switches to eval mode; make sure dropout is active again.
        self._set_train_mode(True)

        # Discriminator update. Generator outputs are constants here, so no
        # autograd graph is built through the encoder/decoder.
        self.opt_disc.zero_grad()
        with torch.no_grad():
            generated = self._generate(x_real, z_real)
        loss_disc = self._adversarial_loss(x_real, z_real, *generated,
                                           generator=False)
        loss_disc.backward()
        self.opt_disc.step()

        # Generator (encoder + decoder) update. Gradients that also land on
        # the discriminators are cleared by opt_disc.zero_grad() above on
        # the next call.
        self.opt_gen.zero_grad()
        x_gen, z_gen, x_rec, z_rec = self._generate(x_real, z_real)
        loss_gen = self._adversarial_loss(x_real, z_real, x_gen, z_gen,
                                          x_rec, z_rec, generator=True)
        if self.add_recon_loss:
            loss_recon = torch.mean((x_real - x_rec) ** 2)
            loss_gen = loss_gen + loss_recon * self.lambda_recon_loss
        loss_gen.backward()
        self.opt_gen.step()

        self.hist_loss_disc.append(loss_disc.item())
        self.hist_loss_gen.append(loss_gen.item())

    def _scale_data(self, X):
        """Standardize ``X`` with the fitted scaler, or return a copy of it
        (a copy, because training shuffles the array in place)."""
        if self.preprocessing:
            return self.scaler_.transform(X)
        return np.copy(X)

    def _train_iterations(self, X_norm, n_iterations, noise_std):
        """Run ``n_iterations`` training steps on random mini-batches.

        ``X_norm`` is shuffled in place every iteration. Gaussian noise with
        standard deviation ``noise_std`` is added to each mini-batch.
        """
        batch_size = min(self.batch_size, X_norm.shape[0])
        for n in range(n_iterations):
            if n % 50 == 0 and n != 0 and self.verbose == 1:
                print(f'Train iter: {n}')

            np.random.shuffle(X_norm)
            x_batch = X_norm[:batch_size].astype(np.float32)
            latent_noise = np.random.normal(
                0, 1, (x_batch.shape[0], self.latent_dim))
            x_batch += np.random.normal(0, noise_std, size=x_batch.shape)
            self.train_step((x_batch, latent_noise))

    def _score_training_data(self, X):
        """Score ``X`` in its original row order and derive
        ``decision_scores_`` (plus threshold and labels)."""
        self.decision_scores_ = self.get_outlier_scores(self._scale_data(X))
        self._process_decision_scores()

    def fit(self, X, y=None, noise_std=0.1):
        """Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.
        y : Ignored
            Not used, present for API consistency by convention.
        noise_std : float, optional (default=0.1)
            Standard deviation of the Gaussian noise added to every training
            mini-batch.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        X = check_array(X)
        self._set_n_classes(y)

        self.n_samples_, self.n_features_ = X.shape[0], X.shape[1]
        self._build_model()

        if self.preprocessing:
            self.scaler_ = StandardScaler()
            X_norm = self.scaler_.fit_transform(X)
        else:
            X_norm = np.copy(X)

        self._train_iterations(X_norm, self.epochs, noise_std)
        self._score_training_data(X)
        return self

    def train_more(self, X, epochs=100, noise_std=0.1):
        """Continue training a fitted detector for ``epochs`` more
        iterations, then recompute ``decision_scores_`` on ``X``.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.
        epochs : int, optional (default=100)
            Number of additional training iterations.
        noise_std : float, optional (default=0.1)
            Standard deviation of the Gaussian noise added to every training
            mini-batch.

        Returns
        -------
        self : object
            The further trained estimator.
        """
        check_is_fitted(self, ['decision_scores_'])
        X = check_array(X)

        self._train_iterations(self._scale_data(X), epochs, noise_std)
        self._score_training_data(X)
        return self

    def get_outlier_scores(self, X_norm):
        """Compute outlier scores for already pre-processed samples.

        Parameters
        ----------
        X_norm : array-like of shape (n_samples, n_features)
            Samples, already standardized when ``preprocessing=True``.

        Returns
        -------
        outlier_scores : numpy array of shape (n_samples,)
            Squared difference between discriminator_xx's output on (x, x)
            and on (x, G(E(x))). Higher means more abnormal.
        """
        # Dropout must be off at inference, otherwise scores are random.
        self._set_train_mode(False)
        x = self._to_tensor(X_norm)
        with torch.no_grad():
            x_rec = self.dec(self.enc(x))
            out_true_xx = self.disc_xx(torch.cat((x, x), dim=1))
            out_fake_xx = self.disc_xx(torch.cat((x, x_rec), dim=1))
        diff = (out_true_xx - out_fake_xx).cpu().numpy()
        # disc_xx has one output unit: the mean over axis 1 only flattens
        # the (n_samples, 1) array.
        return np.mean(diff ** 2, axis=1)

    def decision_function(self, X):
        """Predict raw anomaly score of X using the fitted detector.

        The anomaly score of an input sample is computed based on different
        detector algorithms. For consistency, outliers are assigned with
        larger anomaly scores.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The training input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.

        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
        """
        check_is_fitted(self, ['decision_scores_'])
        X = check_array(X)
        return self.get_outlier_scores(self._scale_data(X))

    def plot_learning_curves(self, start_ind=0, window_smoothening=10):
        """Plot rolling-mean generator and discriminator training losses.

        Parameters
        ----------
        start_ind : int, optional (default=0)
            Index of the first recorded training iteration to plot.
        window_smoothening : int, optional (default=10)
            Window size of the rolling mean applied to the loss histories.
        """
        fig = plt.figure(figsize=(12, 5))
        curves = (('Generator', self.hist_loss_gen),
                  ('Discriminator(s)', self.hist_loss_disc))
        for position, (title, history) in enumerate(curves, start=1):
            smoothed = pd.Series(history[start_ind:]).rolling(
                window=window_smoothening).mean()
            ax = fig.add_subplot(1, 2, position)
            ax.plot(range(len(smoothed)), smoothed)
            ax.set_title(title)
            ax.set_ylabel('Loss')
            ax.set_xlabel('Iter')

        plt.show()
